Conversation
willtebbutt
left a comment
There was a problem hiding this comment.
Approved subject to CI etc. I'd also like an approval from either @devmotion or @st-- before merging
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #499 +/- ##
==========================================
- Coverage 94.25% 92.83% -1.42%
==========================================
Files 52 52
Lines 1374 1396 +22
==========================================
+ Hits 1295 1296 +1
- Misses 79 100 +21 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
I'm a bit worried that this might cause performance regressions, or at least is suboptimal, based on the fixes that were needed in Distributions: JuliaStats/Distributions.jl#1492 |
|
Good point @devmotion . @theogf could you check some simple examples with Zygote? |
Then shouldn't I just reuse the |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…nelFunctions.jl into tgf/checkargs
|
@devmotion do you think that would be enough? |
|
Did you do compare performance with Zygote? |
|
Here are the benchmarks: using Zygote
using BenchmarkTools
using KernelFunctions
x = [3.0]
macro old_check_args(K, param, cond, desc=string(cond))
quote
if !($(esc(cond)))
throw(
ArgumentError(
string(
$(string(K)),
": ",
$(string(param)),
" = ",
$(esc(param)),
" does not ",
"satisfy the constraint ",
$(string(desc)),
".",
),
),
)
end
end
end
struct OldLinearKernel{Tc<:Real} <: KernelFunctions.SimpleKernel
c::Vector{Tc}
function OldLinearKernel(c::Real)
@old_check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
return new{typeof(c)}([c])
end
end
function f(x)
k = LinearKernel(;c = x[1])
sum(k.c)
end
function g(x)
k = LinearKernel(;c = x[1], check_args=false)
sum(k.c)
end
function h(x)
k = OldLinearKernel(x[1])
sum(k.c)
end
@btime Zygote.gradient($f, $x) # 15.980 μs (150 allocations: 5.89 KiB)
@btime Zygote.gradient($g, $x) # 13.853 μs (142 allocations: 5.72 KiB)
@btime Zygote.gradient($h, $x) # 4.700 μs (51 allocations: 1.89 KiB) |
|
That's a quite noticeable regression. Do we know what exactly causes it? |
|
For completeness I added the constructor function OldLinearKernel(c::Real; check_args=true)
check_args && @old_check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
return new{typeof(c)}([c])
endAnd these are the results function h(x)
k = OldLinearKernel(x[1])
sum(k.c)
end
function i(x)
k = OldLinearKernel(x[1]; check_args=false)
sum(k.c)
end
@btime Zygote.gradient($h, $x) # 6.086 μs (65 allocations: 2.58 KiB)
@btime Zygote.gradient($i, $x) # 10.404 μs (91 allocations: 3.38 KiB) |
|
Looking back at this, how about using |
Summary
Similarly to Distributions.jl we do not necessarily want to check for argument correctness all the time.
Proposed changes
This adds a
check_argskeyword to all constructors that require it to deactivate the check for the correctness for the arguments.Additionally this drops 1.3 for 1.6 because I am lazy to write
kwarg=kwarg(appears in 1.5 or 1.6 can't remember0What alternatives have you considered?
We could define a reparametrization of the parameters but we already discussed that it should be the duty of the user to do this.
Breaking changes
None! Crazy right?